{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.distributions\n",
    "from torch.distributions.binomial import Binomial\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Binomial Distribution\n",
    "\n",
    "Consider a dataset of photos that contains 20% of celebrity faces. Let us design an experiment where we select 3 photos from the dataset. The probability of success (picking up a celebrity photo) in a single trial is 0.2 `(p=0.2)`. Probability of failure (not picking up a celebrity photo) is `1 - p = 0.8`.\n",
    "\n",
    "Say, we do `n=3` trials. The binomial distribution gives us the probability of picking up `k = 0` or `k = 1` or `k = 2` or `k = 3` celebrity photos in the 3 trials.\n",
    "\n",
    "Formally, Binomial distribution is the probability distribution of `k` successes in `n` trials, where the probability of success in any given trial is known to be `(p)`.\n",
    "\n",
    "This is given by the formula \n",
    "                          $$ P(X=k) = {n \\choose k} p^k (1-p)^{ n-k} $$\n",
    "where $ {n \\choose k} = \\frac{n!}{k!(n-k)!}$ (this is the number of ways in which we can choose k slots from n possible slots).\n",
    "\n",
    "For our experiment, `n=3` and `p=0.2`. Substituting different values of `k` from 0 to 3, we have the following probability distribution\n",
    "<table>\n",
    "    <tr>\n",
    "        <td>P(X=0) <td>0.512\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td>P(X=1)<td>0.384\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td>P(X=2)<td>0.096\n",
    "    <tr>\n",
    "    <tr>\n",
    "        <td>P(X=3)<td>0.008\n",
    "    <tr>\n",
    "<table>\n",
    "    \n",
    "The mean of the distribution is given by n x p = 3 * 0.2 = 0.6\n",
    "\n",
    "This is demonstrated by the PyTorch code below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indicates the number of times the coin is flipped in one experiment\n",
    "num_trials = 3 \n",
    "\n",
    "# Probability of success\n",
    "p = 0.2 \n",
    "\n",
    "# Instantiate the binomial distribution\n",
    "binomial_dist = Binomial(num_trials, probs=torch.Tensor([p]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -0.957\n",
      "Raw eval Log Prob: -0.957\n"
     ]
    }
   ],
   "source": [
    "# Instantiate single point test dataset\n",
    "X = torch.tensor([1], dtype=torch.float)\n",
    "\n",
    "def nCk(n, k):\n",
    "    f = math.factorial\n",
    "    return f(n) * 1. / (f(k) * f(n-k))\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, n, p):\n",
    "    result = nCk(n, X) * (p ** X) * (1 - p) ** (n - X)\n",
    "    return torch.log(result)\n",
    "\n",
    "# Evaluate log-prob using PyTorch distributions function call\n",
    "log_prob = binomial_dist.log_prob(X)\n",
    "print(\"Log Prob: {:.3f}\".format(log_prob[0]))\n",
    "\n",
    "# Evaluate log-prob using formula\n",
    "raw_eval_log_prob = raw_eval(X, num_trials, p)\n",
    "print(\"Raw eval Log Prob: {:.3f}\".format(raw_eval_log_prob[0]))\n",
    "\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of experiment runs (each with 3 trials)\n",
    "num_samples = 1000000 \n",
    "\n",
    "# Draw samples. Each element of the samples array represent the number of successes in that experiment.\n",
    "samples = binomial_dist.sample(torch.Size([num_samples]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample mean: 0.6000099778175354\n",
      "Dist Mean: tensor([0.6000])\n",
      "Sample variance: 0.4804621934890747\n",
      "Dist Variance: tensor([0.4800])\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples. Denotes the average number of successes\n",
    "sample_mean = samples.mean()\n",
    "print(\"Sample mean: {}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from Pytorch\n",
    "dist_mean = binomial_dist.mean\n",
    "print(\"Dist Mean: {}\".format(dist_mean))\n",
    "\n",
    "# As expected, the two means approximately match and are equal to num_trials * p = 3 * .2 = 0.6\n",
    "assert torch.isclose(sample_mean, dist_mean, atol=0.2)\n",
    "\n",
    "# The variance obtained from the samples \n",
    "sample_var = binomial_dist.sample([num_samples]).var()\n",
    "print(\"Sample variance: {}\".format(sample_var))\n",
    "\n",
    "# The variance of the distribution from Pytorch\n",
    "dist_var = binomial_dist.variance\n",
    "print(\"Dist Variance: {}\".format(dist_var))\n",
    "\n",
    "# As expected, the two variances approximately match.\n",
    "assert torch.isclose(sample_var, dist_var, atol=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 success: 51.2%\n",
      "1 success: 38.4%\n",
      "2 success: 9.6%\n",
      "3 success: 0.8%\n"
     ]
    }
   ],
   "source": [
    "# Let us now plot the distribution of the observed data\n",
    "dist = []\n",
    "for i in range(num_trials + 1):\n",
    "    fraction = float(torch.sum(samples == i)) / num_samples\n",
    "    dist.append((i, fraction))\n",
    "    print(\"{} success: {}%\".format(i, round(fraction * 100, 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_binomial_dist(dist):\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    x, y = [[d[0] for d in dist], [d[1] for d in dist]]\n",
    "    ax.plot(x, y, 'bo')\n",
    "    ax.vlines(x, 0, y, colors='r', alpha=0.5)\n",
    "    \n",
    "    ax.set_ylim(0)\n",
    "    ax.set_xlabel(\"Number of Successes\", fontsize=12)\n",
    "    ax.set_ylabel(\"P(X=k)\", fontsize=12)\n",
    "    plt.xticks(fontsize=12)\n",
    "    plt.yticks(fontsize=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfsAAAF9CAYAAAADYsEXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi40LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcv7US4rQAAHNhJREFUeJzt3XucXWdd7/HPN00pTdKILaHIJYlAy6UeixpBKRWUgwJSW1/VI3RaLUIjrR68crzUIlR6Kiio51iBeFoKdSxobaVYOEfUA7aiYlAqDZcIpWmhLQ2EXiaB3vidP9bKyc64k8w0M2uSZz7v12u/Ztazn7X2bz+zku88a69ZK1WFJElq15KFLkCSJM0vw16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGLV3oAubKIx/5yFq7du1ClyFJ0iA++tGPfqmqVs2kbzNhv3btWjZu3LjQZUiSNIgkW2ba18P4kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYb9NJOTsHYtLFnSfZ2cXOiKJEnaP81cG38uTE7C+vWwY0e3vGVLtwwwMbFwdUmStD+c2Y8499xdQb/Tjh1duyRJByvDfsTNN8+uXZKkg4FhP2L16tm1S5J0MDDsR1xwASxbtnvbsmVduyRJByvDfsTEBGzYAGuOmiIUa9Z0y56cJ0k6mA0W9kmOTHJVku1JtiQ5bQ/9Xpvk/iRTI48nDFXnxATc9DtX8PW3v4ObbjLoJUkHvyH/9O4i4D7gaODpwDVJrq+qTWP6vruqTh+wNkmSmjXIzD7JcuBU4Lyqmqqq64CrgTOGeH1JkhazoQ7jHws8UFWbR9quB47bQ/+TkmxLsinJ2XvaaJL1STYm2bh169a5rFeSpGYMFfYrgLuntd0FHDGm758CTwVWAWcBr0ny0nEbraoNVbWuqtatWrVqLuuVJKkZQ4X9FLByWttK4J7pHavqE1V1a1U9WFUfBn4f+JEBapQkqUlDhf1mYGmSY0bajgfGnZw3XQGZl6okSVoEBgn7qtoOXAmcn2R5khOAk4HLpvdNcnKSb0znGcCrgPcMUackSS0a8qI65wCHA3cAlwNnV9WmJCcmmRrp9xLgM3SH+N8JvKGq3jFgnZIkNWWwv7Ovqm3AKWPar6U7gW/n8tiT8SRJ0kPj5XIlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMGC/skRya5Ksn2JFuSnLaP/g9L8skknx+qRkmSWrR0wNe6CLgPOBp4OnBNkuuratMe+r8a2AocMVB9kiQ1aZCZfZLlwKnAeVU1VVXXAVcDZ+yh/zcDpwMXDlGfJEktG+ow/rHAA1W1eaTteuC4PfT/n8CvAV+d78IkSWrdUGG/Arh7WttdjDlEn+SHgUOq6qp9bTTJ+iQbk2zcunXr3FQqSVJjhgr7KWDltLaVwD2jDf3h/jcCr5rJRqtqQ1Wtq6p1q1atmpNCJUlqzVAn6G0GliY5pqr+vW87Hph+ct4xwFrg2iQADwO+IcntwHdV1U3DlCtJUjsGCfuq2p7kSuD8JK+gOxv/ZOBZ07reADx+ZPlZwB8A3053Zr4kSZqlIS+qcw5wOHAHcDlwdlVtSnJikimAqnqgqm7f+QC2AV/vlx8csFZJkpox2N/ZV9U24JQx7dfSncA3bp0PAo+b38okSWqbl8uVJKlxhr00C5OTsHYtLFnSfZ2cXOiKJGnfhrxcrnRQm5yE9ethx45uecuWbhlgYmLh6pKkfXFmL83QuefuCvqdduzo2iXpQGbYSzN0882za5ekA4VhL83Q6tWza5ekA4VhL83QBRfAsmW7ty1b1rVL0oHMsJdmaGICNmyANUdNEYo1a7plT86TdKDzbHxpFiYmYOL+K7qFM89c0Fokaaac2UuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGDRb2SY5MclWS7Um2JDltD/1+PsmNSe5OcmuS302ydKg6JUlqzZAz+4uA+4CjgQngLUmOG9PvauDbq2ol8C3A8cCrBqtSkqTGDBL2SZYDpwLnVdVUVV1HF+pnTO9bVZ+tqjt3rgp8HXjSEHVKktSioWb2xwIPVNXmkbbrgXEze5KcluRu4Et0M/u3zX+JkiS1aaiwXwHcPa3tLuCIcZ2r6k/6w/jHAm8FvjiuX5L1STYm2bh169a5rFeSpGYMFfZTwMppbSuBe/a2UlX9O7AJ+MM9PL+hqtZV1bpVq1bNSaGSJLVmRme5J3kU8AN0h9QfAdxJdxj+A1V1+ww2sRlYmuSYPsDpt7VphjU+cSZ1SpKk/2ivM/skT01yBfBJupPpDgVu77+eAWxKckWSp+1tO1W1HbgSOD/J8iQnACcDl415zVf0v1zQb/dXgb+Z9TuTJEnAvmf2lwK/DUxU1b3Tn0xyGPBDwMXAd+9jW+cAlwB3AF8Gzq6qTUlOBN5fVSv6ficAFyRZAWwF/gw4b2ZvR5IkTbfXsK+qZ+7j+XvpwvjP9vVCVbUNOGVM+7V0J/DtXH7ZvrYlSZJmbsYn6CX5sT20v27uypEkSXNtNmfjX5jkhaMNSS6kO4wvSZIOULMJ+x8E3tp/xk6SNwPPB75vPgqTJElzY8Y3mKmqTyb5YeA9Sf4eWA18X1VNv1iOJEk6gOw17JOMm7VfDPwU8EpgXRKq6m/nozhJkrT/9jWzv3gP7V8Dfq//voAnzFlFkiRpTu3rT+++eahCJEnS/BjyfvaSJGkBzCrsk1w98v01c1+OJEmaa7Od2T975PsT57IQSZI0P2Yb9pmXKiRJ0ryZbdjXvFQhSZLmjTN7SZIa58xekqTG7c/M3lm+JEkHgdmG/XUj3187l4VIkqT5Mauwr6qTRr5/0dyXI0mS5to+wz7J6n08/7y5K0eSJM21mczsr0/y09MbkzwiyaXA5JxXJUmS5sxMwv4/A2cluTbJMQBJfgz4VL/+cfNYnyRJ2k/7usUtVfXRJOuAc4F/TvJx4HHAy6rq/fNdoCRJ2j8zOkGvqh4AbqP7O/unABuBf5rHuiRJ0hyZyQl6T0jyt8DPAS8CngjcCXwiyUvmuT5JkrSfZjKz/1fgw8DTq+ofquruqjoLOB24IMl757VCSZK0X2YS9s+pql+vqvtGG6vqr4FvBT47L5VJkqQ5MZMT9D62l+e20x3elyRJB6i9zuyTvDnJo/fR59FJ3jy3ZUmSpLmyr5n9p4GPJPkk8KF++R7gCOBY4LnAk4HXz2ONkiRpP+w17KvqbUkuAU4GXgicAjwC+Arwb8Bbgff2f5onSZIOQPv8zB44FPh2YBXwfuDCqvravFYlSZLmzEzOxr8IOInu8rinAr89rxVJkqQ5NZOwfwHw/VX13+gO5b94fkuSJElzaSZhv7yqbgOoqluAb5jfkiRJ0lyayWf2S5N8L5A9LFNVfzsfxUmSpP03k7C/A7hkZPnL05YLeMJcFiVJkubOTK6gt3aAOiRJ0jyZ0S1uJUnSwcuwlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1LjBwj7JkUmuSrI9yZYkp+2h36uT3JDkniSfS/LqoWqUJKlFSwd8rYuA+4CjgacD1yS5vqo2TesX4MeBfwOeCPxVkluq6l0D1ipJUjMGmdknWQ6cCpxXVVNVdR1wNXDG9L5V9caq+peqeqCqPg28BzhhiDolSWrRUIfxjwUeqKrNI23XA8ftbaUkAU4Eps/+dz6/PsnGJBu3bt06Z8VKktSSocJ+BXD3tLa7gCP2sd5r6Wp8+7gnq2pDVa2rqnWrVq3a7yIlSWrRUJ/ZTwErp7WtBO7Z0wpJfobus/sTq+reeaxNkqSmDTWz3wwsTXLMSNvx7Pnw/E8CvwI8r6o+P0B9kiQ1a5Cwr6rtwJXA+UmWJzkBOBm4bHrfJBPAfweeX1U3DlGfJEktG/KiOucAhwN3AJcDZ1fVpiQnJpka6fd64Cjgn5NM9Y+3DlinJElNGezv7KtqG3DKmPZr6U7g27n8zUPVJEnSYuDlciVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wYL+yRHJrkqyfYkW5Kctod+35vk/ya5K8lNQ9UnSVKrhpzZXwTcBxwNTABvSXLcmH7bgUuAVw9YmyRJzRok7JMsB04Fzquqqaq6DrgaOGN636r6SFVdBtw4RG2SJLVuqJn9scADVbV5pO16YNzMXpIkzaGhwn4FcPe0truAI/Zno0nWJ9mYZOPWrVv3Z1OStF8mJ2HtWliypPs6ObnQFUm7DBX2U8DKaW0rgXv2Z6NVtaGq1lXVulWrVu3PpiTpIZuchPXrYcsWqOq+rl9v4OvAMVTYbwaWJjlmpO14YNNAry9J8+bcc2HHjt3bduzo2qUDwSBhX1XbgSuB85MsT3ICcDJw2fS+SZYkeThwaLeYhyd52BB1StJDcfPNs2uXhjbkn96dAxwO3AFcDpxdVZuSnJhkaqTf9wBfBd4HrO6//6sB65SkWVm9enbt0tCWDvVCVbUNOGVM+7V0J/DtXP4gkKHqkqT9dcEF3Wf0o4fyly3r2qUDgZfLlaT9NDEBGzbAmqOmCMWaNd3yxMRCVyZ1BpvZS1LLJiZg4v4ruoUzz1zQWqTpnNlLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJapxhL0lS4wx7SZIaZ9hLktQ4w16SpMYZ9pIkNc6wlySpcYa9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjTPsJUlqnGEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXJKlxhr0kSY0z7CVJmkeTk7B2LSxZ0n2dnBy+hqXDv6QkSYvD5CSsXw87dnTLW7Z0ywATE8PV4cxekqR5cu65u4J+px07uvYhGfaSJM2Tm2+eXft8GSzskxyZ5Kok25NsSXLaHvolyRuSfLl/vCFJhqpTkqS5snr17Nrny5Az+4uA+4CjgQngLUmOG9NvPXAKcDzwrcBJwE8NVaQkSXPlggtg2bLd25Yt69qHNEjYJ1kOnAqcV1VTVXUdcDVwxpjuPwG8qao+X1VfAN4EnDlEnZIkzaWJCdiwAdYcNUUo1qzploc8OQ+GOxv/WOCBqto80nY98JwxfY/rnxvtN+4IgCRJB7yJCZi4/4pu4cwzF6SGVNX8v0hyIvBnVfXokbazgImqeu60vg8Cx1XVp/rlY4DNwJKaVmyS9XSH/QGeDHx6Dst+JPClOdzewc7x2J3jsYtjsTvHY3eOxy5zPRZrqmrVTDoONbOfAlZOa1sJ3DODviuBqelBD1BVG4ANc1XkqCQbq2rdfGz7YOR47M7x2MWx2J3jsTvHY5eFHIuhTtDbDCztZ+k7HQ9sGtN3U//cvvpJkqQZGCTsq2o7cCVwfpLlSU4ATgYuG9P9ncAvJHlskscAvwhcOkSdkiS1aMg/vTsHOBy4A7gcOLuqNiU5McnUSL+3Ae8FPg7cAFzTtw1tXj4eOIg5HrtzPHZxLHbneOzO8dhlwcZikBP0JEnSwvFyuZIkNc6wlySpcYs67L1e/y6zGIvXJrk/ydTI4wlD1zvfkvxMko1J7k1y6T76/nyS25PcneSSJIcNVOYgZjoWSc5M8uC0feO5w1U6/5IcluTi/t/IPUk+luSFe+nf+r4x4/FYDPsHQJI/TnJb/zPfnOQVe+k72P6xqMMer9c/aqZjAfDuqlox8rhxsCqHcyvweuCSvXVK8gPArwDPA9YATwBeN+/VDWtGY9H7h2n7xgfnt7TBLQVuobv65zcAvw78aZK10zsukn1jxuPRa33/ALgQWFtVK4EfAl6f5Dumdxp6/1i0Ye/1+neZ5VgsClV1ZVX9BfDlfXT9CeDiqtpUVV8BfpOG9g2Y1Vg0r6q2V9Vrq+qmqvp6Vf0l8DngP/xnzuLYN2YzHotC//O+d+di/3jimK6D7h+LNuzZ8/X6x81mW79e/2zGAuCkJNuSbEpy9vyXd0Abt28cneSoBapnoX1bki/1hy/PSzLUVToXRJKj6f79jLvw16LbN/YxHrBI9o8kf5hkB/Ap4DbgfWO6Dbp/LOawXwHcPa3tLuCIPfS9a1q/FQ19bj+bsfhT4KnAKuAs4DVJXjq/5R3Qxu0bMH7sWvd3wLcAj6I7UvRS4NULWtE8SnIoMAm8Y+e9PKZZVPvGDMZj0ewfVXUO3c/5RLoLyt07ptug+8diDvt5uV7/QWrGY1FVn6iqW6vqwar6MPD7wI8MUOOBaty+AeP3o6ZV1Y1V9bn+cO7HgfNpdN9IsoTuCqD3AT+zh26LZt+YyXgspv0DoP8/8jrgccC4I6CD7h+LOey9Xv8usxmL6Qpo5QjHQzFu3/hiVS36z7dpdN/oj+hdTHcy66lVdf8eui6KfWMW4zFdk/vHGEsZ/5n9oPvHog17r9e/y2zGIsnJSb6x/3PEZwCvAt4zbMXzL8nSJA8HDgEOSfLwPXy++E7g5UmeluQRdGcjXzpgqfNupmOR5IX9Z7YkeQpwHg3uG8Bb6D7KOqmqvrqXfs3vG70Zjcdi2D+SPCrJS5KsSHJIf8b9S4G/GdN92P2jqhbtAzgS+AtgO3AzcFrffiLdYfqd/QK8EdjWP95If6nhVh6zGIvL6c7KnqI7+eRVC137PI3Ha9l1Ju3Ox2uB1f17Xz3S9xeAL9Kd9/B24LCFrn8hxgL4nX4ctgM30h2mPXSh65/jsVjTv/+v9e9952Nike4bMx6PRbJ/rAI+BNzZ/8w/DpzVP7eg+4fXxpckqXGL9jC+JEmLhWEvSVLjDHtJkhpn2EuS1DjDXpKkxhn2kiQ1zrCXDkJJLk3y+gV67SR5e5KvJPnIQtQgaXYMe2kOJLkpyR397YJ3tr0iyQcXsKz58mzg+cDjquoZ059M8rAkb0ry+SRT/dj83vBlStrJsJfmziHAzy50EbOV5JBZrrIGuKm6yyyP86vAOuAZdHfwei7wLw+5QEn7zbCX5s5vA7/UX+d6N0nWJqnRa8on+WCSV/Tfn5nk75P8bpI7k9yY5Fl9+y39UYOfmLbZRyb5QJJ7knwoyZqRbT+lf25bkk8n+S8jz12a5C1J3pdkO/C9Y+p9TJKr+/U/k+Ssvv3lwP8Cvruftb9uzDh8J3BVdXdHrKq6qareObLtSvKkafW8fmT55CQfS3J3ks8meUHffmT/8cGt/UcIfzGyzov7de5M8uEk3zry3C8n+UI/Tp9O8ry+/RlJNvav88Ukbx5Z57v67dyZ5Pokzx157sz+53NPks8lmRgzBtIBZdyNPSQ9NBuBDwK/RHdTi9l6Jl2QHgW8DngX8F7gScBzgD9P8udVNdX3nwB+EPgnuvs1TALP7j9K+ADwGuCFwH8CPpDkhqr6RL/uacCLgBcDDxtTy7uAG4DHAE/p1/9sVV2c5EHgFVX17D28j3+ku3HUfcC1wA01w+typ7u50jvpbn36N8A3sev+3pfRXVv8uP7rs/p1vg24BDiJ7mdwOnB1kicDa+luufqdVXVrkrV0R2Cguz3z71fVZUlW0N1rnSSPBa4BzgD+N/A8urF/CrAD+B/99j6d5Jvo7ishHdCc2Utz6zXAf02y6iGs+7mqentVPQi8G3g8cH5V3VtVf0V3r/AnjfS/pqr+rqruBc6lm20/ni7Ab+q39UBV/Svw58CPjqz7nqr6++ruLf610SL6bZwA/HJVfa2qPkb3S8iPz/B9XAi8ge6XkY3AF8YcldiTlwOXVNUH+tq+UFWf6kP1hcArq+orVXV/VX2oX2c98Laq+qfq7iH+DuBe4LuAB4HDgKclObQ/yvDZfr37gScleWRVTVXVP/btpwPvq6r39TV8oH8fL+qf/zrwLUkOr6rbqqql212rUYa9NIeq6gbgL4FfeQirf3Hk+6/225vetmJk+ZaR152iuyPjY+g+U39mfwj6ziR30gXvo8etO8ZjgG1Vdc9I2xbgsTN5E33gXlRVJwCPAC4ALkny1Bms/njgs3to31ZVXxnz3BrgF6e938cDj6mqzwA/R3eXvjuSvCvdbaqh+8XiWOBTSf45yYtHtvej07b3bOCb+vMUfgx4JXBbkmv6Gb90QDPspbn3G8BZ7B6OO09mWzbSNhq+D8Xjd37TH4Y+EriVLsg/VFWPGHmsqKqzR9bd22H1W4Ejkxwx0rYa+MJsC6yqr1bVRcBXgKf1zTvY8zjcAjxxzKZu6Wv6D+dD9M9dMO39Lquqy/sa/qT/yGHn7Vjf0Lf/e1W9FHhU33ZF/xHILcBl07a3vKp+q1/v/1TV8+k+YvgU8EezHRdpaIa9NMf62eS7gVeNtG2lC8vTkxyS5CcZH2qz8aIkz07yMOA3gX+sqlvojiwcm+SMJIf2j++c4cyafhsfBi5M8vD+ZLeXA388k/WT/FyS5yY5PMnS/hD+EcC/9l0+BpzWj8ML6M5H2Oli4GVJnpdkSZLHJnlKVd0GvB/4wyTf2L+n7+nX+SPglUmemc7yJD+Y5IgkT07yfUkOo7vn+lfpDsOT5PQkq6rq63T3H6d/7o+Bk5L8QF/jw/v387gkR/cnEC6n+6hgauf2pAOZYS/Nj/OB5dPazgJeDXyZ7iSzD+/na/wJ3VGEbcB30H3WTH/4/fuBl9DN0m+nm7keNottv5Tu5LZbgauA36iqv57hujuAN/Wv+yXgp4FTq+rG/vmfpTuZbufHC///rPqq+gjwMuB3gbuAD9HNyKE7Ye5+utn0HXSH56mqjXRj+wd0RxA+A5zZr3MY8Ft9HbfTzeJ/tX/uBcCmJFN0J+u9pD8ScQtwMvBrwFa6mf6r6f6/XAL8Qj8u2+h+URk9YiIdkDLDk2QlSdJBypm9JEmNM+wlSWqcYS9JUuMMe0mSGmfYS5LUOMNekqTGGfaSJDXOsJckqXGGvSRJjft/xHko3zTw12EAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_binomial_dist(dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Visualization\n",
    "\n",
    "Here we allow the user to set different values for the $p$ and the number of trials for a Binomial distribution and visualise the resulting probability distribution function\n",
    "\n",
    "Note: In order to run this section, please download the notebook. Interactive snippets do not work online. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73b916e8a3fe41748b061e87ecad9b97",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to  previous…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89d104ce21ad42679ce48a56b21a0adb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.5, description='p', max=1.0), IntText(value=3, description='num_tria…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interact, interact_manual\n",
    "\n",
    "fig_1, ax_1 = plt.subplots(nrows=1, ncols=1)\n",
    "\n",
    "@interact\n",
    "def plot_binomial_distribution(p=(0.0, 1.0), num_trials=widgets.IntText(3)):\n",
    "    def _reset_plot(ax):\n",
    "        ax.clear()\n",
    "        ax.set_ylim(0)\n",
    "        ax.set_title(\"Binomial Distribution\")\n",
    "        ax.set_xlabel(\"Number of Successes\", fontsize=12)\n",
    "        ax.set_ylabel(\"P(X=k)\", fontsize=12)\n",
    "        \n",
    "    num_samples = 1000\n",
    "    binomial_dist = Binomial(num_trials, probs=torch.Tensor([p]))\n",
    "    samples = binomial_dist.sample([num_samples])\n",
    "    dist = []\n",
    "    for i in range(num_trials + 1):\n",
    "        fraction = float(torch.sum(samples == i)) / num_samples\n",
    "        dist.append((i, fraction))\n",
    "    x, y = [[d[0] for d in dist], [d[1] for d in dist]]\n",
    "    _reset_plot(ax_1)\n",
    "    ax_1.plot(x, y, 'bo')\n",
    "    ax_1.vlines(x, 0, y, colors='r', alpha=0.5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
